import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from generate_T import find_invertible_submatrix, Generator_matrix
# Preferred accelerator: the second GPU ("cuda:1") when at least two GPUs are
# visible; fall back to "cuda:0" on single-GPU hosts (where "cuda:1" would fail
# with an invalid-device-ordinal error on first use) and to the CPU when CUDA
# is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

class MyModel(nn.Module):
    """Fuse an embedding MLP with a two-branch CNN over embedding Gram matrices.

    For a batch of token-index sequences ``x`` the model computes

    * ``R``: ``self.fc`` applied to the flattened embeddings ``E1``;
    * ``C``: ``self.fc1`` applied to the concatenated features of two CNN
      branches, one fed ``tanh(E1 @ E1^T)`` and one fed
      ``tanh(T @ (T @ E1 @ E1^T)^T)`` with the fixed matrix ``self.T``;

    and returns ``s * C + (1 - s) * R`` where ``s = sigmoid(fusion_weight)`` is
    learned.

    Args:
        embed_size: Embedding dimension of ``self.embedding1``.
        inport_length: Number of embeddings (``num_embeddings``) of
            ``self.embedding1``.
        data_length: Not used by this class; kept for signature compatibility.
        dropout_prob: Dropout probability used inside both MLP heads.

    Side effects:
        ``__init__`` writes the selected matrix to ``sub_T.txt`` in the current
        working directory.
    """

    def __init__(self, embed_size, inport_length, data_length, dropout_prob=0.2):
        super(MyModel, self).__init__()
        # Arguments forwarded to the project helpers Generator_matrix /
        # find_invertible_submatrix.
        size, new_size, l, m = 9, 9, 3, 3
        T = Generator_matrix(size, l, m)
        sub_T = find_invertible_submatrix(T, new_size)
        # Registered as a buffer (not a plain tensor attribute) so that
        # model.to()/.cuda()/.half() and DataParallel replication carry it
        # along with the parameters. persistent=False keeps "T" out of
        # state_dict(), so existing checkpoints still load unchanged.
        self.register_buffer(
            "T", torch.tensor(sub_T, dtype=torch.float32), persistent=False
        )
        np.savetxt('sub_T.txt', sub_T, fmt='%d')
        self.embedding1 = nn.Embedding(num_embeddings=inport_length, embedding_dim=embed_size)
        # NOTE: despite its name this is a Tanh; the attribute name is kept for
        # compatibility with code that references ``model.relu``.
        self.relu = nn.Tanh()

        # Branch 1: input is tanh of the Gram matrix A1.
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv1_1 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)

        # Branch 2: input is tanh of the T-transformed Gram matrix.
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv1_2 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)

        # Raw mixing logit; forward squashes it with a sigmoid.
        self.fusion_weight = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

        # forward() needs seq_len == T.shape[0] == T.shape[1] (the matmuls with
        # T and the channel-wise concat of both branches), so the head input
        # sizes follow from T rather than being hard-coded. For a 9x9 T and
        # embed_size=256 these are 2304 and 1152.
        side = self.T.shape[0]
        # The padded 3x3 convs keep H and W; each of the three
        # MaxPool2d(kernel_size=2, stride=1) layers shrinks them by 1.
        feat_side = side - 3
        fc_in = side * embed_size  # flattened E1: seq_len * embed_size
        fc1_in = (
            (self.conv3_1.out_channels + self.conv3_2.out_channels)
            * feat_side
            * feat_side
        )

        self.fc = nn.Sequential(
            nn.Linear(fc_in, 1024),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(256, 1),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(fc1_in, 8),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(8, 1),
        )
        # Not used by forward(); kept so the attribute remains available.
        self.dropout = nn.Dropout(dropout_prob)

    @staticmethod
    def _conv_branch(x, conv1, conv2, conv3, pool):
        """Apply three ``conv -> ReLU -> pool`` stages to ``x`` and return the result."""
        for conv in (conv1, conv2, conv3):
            x = pool(F.relu(conv(x)))
        return x

    def forward(self, x):
        """Return the fused prediction for a batch of index sequences.

        Args:
            x: 2-D tensor of indices for ``self.embedding1``, shaped
                ``(batch, seq_len)``; ``seq_len`` must equal ``self.T.shape[0]``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        E1 = self.embedding1(x)
        # MLP head on the flattened embeddings.
        R = self.fc(torch.flatten(E1, 1))

        # Gram matrix of the embeddings: (batch, seq_len, seq_len).
        A1 = torch.matmul(E1, E1.transpose(-2, -1))
        # T @ (T @ A1)^T, i.e. T @ A1 @ T^T because A1 is symmetric.
        AA = torch.matmul(self.T, A1)
        AA = torch.matmul(self.T, AA.transpose(-2, -1))

        # self.relu is a Tanh; unsqueeze adds the single input channel for Conv2d.
        A = self.relu(A1).unsqueeze(1)
        AA = self.relu(AA).unsqueeze(1)

        x_tmp = self._conv_branch(A, self.conv1_1, self.conv2_1, self.conv3_1, self.pool_1)
        x_tmp_ = self._conv_branch(AA, self.conv1_2, self.conv2_2, self.conv3_2, self.pool_2)

        C = torch.cat((x_tmp, x_tmp_), dim=1).flatten(1)
        C = self.fc1(C)

        fusion_weight = torch.sigmoid(self.fusion_weight)
        out = fusion_weight * C + (1 - fusion_weight) * R
        return out